from Network.network_utils import run_optimizer
from ActualCausal.Train.train_utils import compute_likelihood
from ActualCausal.Train.regularizers import apply_regularizers

def train_given_active(args, params, model, buffer, masks, form="full", log_batch=None, wrap_function=None, additional=None, itr_num=0, intermediate_logger=None):
    """Run ``args.active.active_steps`` optimizer steps under given masks.

    Each step samples a batch from ``buffer``, runs ``model.infer`` with the
    per-sample masks, regularizes the negative log-probabilities of the
    selected result entry, and applies one optimizer update.

    Args:
        args: Config namespace; reads ``args.active.given_weighting``,
            ``args.active.active_steps``, ``args.active.include_gradient``,
            ``args.active.full_steps`` and ``args.train.batch_size``.
        params: Provides ``sample_active_weights`` (used when
            ``given_weighting`` is falsy); also forwarded to
            ``apply_regularizers``.
        model: Must expose ``infer`` and ``get_model_optim``.
        buffer: Must expose ``sample(batch_size, weights)`` returning
            ``(batch, idxes)``, plus ``norm_confidence`` / ``eval_binary``.
        masks: Mask container indexable by the sampled ``idxes``. If a string
            is passed (any string), ``buffer.eval_binary`` is used instead.
        form: Model/optimizer key passed to ``model.get_model_optim``.
            ``"all"`` selects the ``"all_mask"`` result entry; anything else
            selects ``"mask"``.
        log_batch: Forwarded to ``model.infer``; defaults to a fresh empty
            list per call.
        wrap_function: Optional callable applied to each sampled batch.
        additional: Forwarded to ``model.infer``; defaults to a fresh empty
            list per call.
        itr_num: Outer iteration counter, used only for the logging step index.
        intermediate_logger: Optional logger; its ``log`` method is called
            once per step when provided.

    Returns:
        The inference result of the last step, with ``reg_loss`` and
        ``gradients`` attached, or ``None`` if no steps were run.
    """
    # Fresh lists per call: mutable defaults would be shared across invocations.
    log_batch = [] if log_batch is None else log_batch
    additional = [] if additional is None else additional

    mask_form = "all_mask" if form == "all" else "mask"
    if isinstance(masks, str):
        # A name instead of mask data: take the masks from the buffer.
        masks = buffer.eval_binary
    weights = buffer.norm_confidence if args.active.given_weighting else params.sample_active_weights

    # Stays None when active_steps is 0, so the return below cannot hit an
    # unbound local.
    result = None
    for i in range(args.active.active_steps):
        batch, idxes = buffer.sample(args.train.batch_size, weights)
        if wrap_function is not None:
            batch = wrap_function(batch)
        mask = masks[idxes]

        result = model.infer(batch, batch.valid * mask, [mask_form], log_batch=log_batch, additional=additional)
        grad_variables = [result.full_active_input] if args.active.include_gradient else []
        compute_models, optims = model.get_model_optim([form])
        optim, compute_model = optims[0], compute_models[0]

        loss = apply_regularizers(-result[mask_form].log_probs, args, params, model, batch, results=result[mask_form])
        result.reg_loss = loss
        result.gradients = run_optimizer(optim, compute_model, loss, grad_variables=grad_variables)
        if intermediate_logger is not None:
            # NOTE(review): the step index scales by args.active.full_steps while
            # this loop runs args.active.active_steps iterations — confirm intended.
            intermediate_logger.log(itr_num * args.active.full_steps + i, {"given": result}, intermediate_name="_given")
    return result